{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mf3kOv1YMB5y"
      },
      "source": [
        "# Pharmacokinetics models with TensorFlow Probability\n",
        "\n",
        "Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-rOdskBSMfQN"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rammI_v-S2LS"
      },
      "source": [
        "This notebook demonstrates how to fit a pharmacokinetic model with TensorFlow Probability. This includes defining the relevant joint distribution and working through the basic steps of a Bayesian workflow, e.g. prior and posterior predictive checks and diagnostics for the inference.\n",
        "\n",
        "There are three main components when building a pharmacokinetic model:\n",
        "\n",
        "\n",
        "1. The pharmacokinetics of the system, which involve solving ordinary differential equations of varying sophistication.\n",
        "2. The treatment the patient undergoes, which we describe with a _clinical event schedule_.\n",
        "3. The heterogeneity between patients: since the data come from multiple patients, we use a hierarchical model (also termed a population model).\n",
        "\n",
        "We'll first tackle a one compartment model with a first-order absorption from the gut.\n",
        "The ODE describing this system is simple enough to be solved analytically.\n",
        "We'll start with a one-dose model for one patient and build our way up to a population model with an event schedule.\n",
        "\n",
        "Next on the to-do list will be expanding these models to nonlinear ODEs. Examples of ODEs that arise in PK models can be found in this [Stan notebook](https://mc-stan.org/events/stancon2017-notebooks/stancon2017-margossian-gillespie-ode.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wl7S_O0AJYWN"
      },
      "source": [
        "ToDo list:\n",
        "- one cpt model (with analytical solution, numerical integrator as an option), one patient, one dose $\\checkmark$\n",
        "- one cpt model, one patient, multiple doses $\\checkmark$\n",
        "- one cpt model, multiple patients, multiple doses $\\checkmark$\n",
        "- Michaelis-Menten PK model, one patient, one dose\n",
        "- Michaelis-Menten PK model, multiple patients, multiple doses.\n",
        "- Friberg-Karlsson PKPD model, one patient, multiple doses.\n",
        "- Friberg-Karlsson PKPD model, multiple patients, multiple doses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kd8S5DK8XqIT"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "tf.executing_eagerly()\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib\n",
        "from matplotlib.pyplot import *\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "matplotlib.pyplot.style.use(\"dark_background\")\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "from jax import numpy as jnp\n",
        "\n",
        "# import tensorflow_datasets\n",
        "from inference_gym import using_jax as gym\n",
        "\n",
        "# import tensorflow as tf\n",
        "from tensorflow_probability.python.internal import prefer_static as ps\n",
        "from tensorflow_probability.python.internal import unnest\n",
        "\n",
        "import tensorflow_probability as _tfp\n",
        "tfp = _tfp.substrates.jax\n",
        "tfd = tfp.distributions\n",
        "tfb = tfp.bijectors\n",
        "\n",
        "tfp_np = _tfp.substrates.numpy\n",
        "tfd_np = tfp_np.distributions \n",
        "\n",
        "from jax.experimental.ode import odeint\n",
        "from jax import vmap\n",
        "\n",
        "import arviz as az\n",
        "from tensorflow_probability.python.internal.unnest import get_innermost"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YDoovE217kDx"
      },
      "outputs": [],
      "source": [
        "# Define nested Rhat for one parameter.\n",
        "# Assume for now the indexed parameter is a scalar.\n",
        "def nested_rhat(result_state, num_super_chains, index_param, num_samples,\n",
        "                warmup_length = 0):\n",
        "  state_param = result_state[index_param][\n",
        "                           warmup_length:(warmup_length + num_samples), :, :]\n",
        "  num_samples = state_param.shape[0]\n",
        "  num_chains = state_param.shape[1]\n",
        "  num_sub_chains = num_chains // num_super_chains\n",
        "  \n",
        "  state_param = state_param.reshape(num_samples, -1, num_sub_chains, 1)\n",
        "\n",
        "  mean_chain = np.mean(state_param, axis = (0, 3))\n",
        "  between_chain_var = np.var(mean_chain, axis = 1, ddof = 1)\n",
        "  within_chain_var = np.var(state_param, axis = (0, 3), ddof = 1)\n",
        "  total_chain_var = between_chain_var + np.mean(within_chain_var, axis = 1)\n",
        "\n",
        "  mean_super_chain = np.mean(state_param, axis = (0, 2, 3))\n",
        "  between_super_chain_var = np.var(mean_super_chain, ddof = 1)\n",
        "\n",
        "  return np.sqrt(1 + between_super_chain_var / np.mean(total_chain_var))\n",
        "\n",
        "# WARNING: this is a very poor estimate for ESS, and we should note\n",
        "# W / B isn't typically used to estimate ESS.\n",
        "def ess_per_super_chain(nRhat):\n",
        "  return 1 / (np.square(nRhat) - 1)"
      ]
    },
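    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick standalone check, here is a self-contained NumPy sketch of nested $\\hat R$ on synthetic, well-mixed draws (the shapes and seed below are made up for illustration). For independent draws the statistic should land very close to 1:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Synthetic draws: (num_samples, num_chains) for one scalar parameter,\n",
        "# grouped into 2 super-chains of 4 sub-chains each.\n",
        "rng = np.random.default_rng(0)\n",
        "num_samples, num_super, num_sub = 500, 2, 4\n",
        "draws = rng.normal(size = (num_samples, num_super * num_sub))\n",
        "grouped = draws.reshape(num_samples, num_super, num_sub)\n",
        "\n",
        "mean_chain = grouped.mean(axis = 0)                    # per sub-chain means\n",
        "between_chain_var = mean_chain.var(axis = 1, ddof = 1)\n",
        "within_chain_var = grouped.var(axis = 0, ddof = 1).mean(axis = 1)\n",
        "total_chain_var = between_chain_var + within_chain_var\n",
        "\n",
        "mean_super_chain = grouped.mean(axis = (0, 2))         # per super-chain means\n",
        "between_super_var = mean_super_chain.var(ddof = 1)\n",
        "\n",
        "nested_rhat_value = np.sqrt(1 + between_super_var / total_chain_var.mean())\n",
        "print(nested_rhat_value)  # close to 1 for well-mixed draws\n",
        "```"
      ]
    },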
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sxh29965zc71"
      },
      "source": [
        "## 1 One compartment model with absorption from the gut\n",
        "\n",
        "A patient orally takes in a drug, which lands in the gut and is then absorbed into a central compartment (e.g. the blood). This process is described by a differential equation:\n",
        "\\begin{eqnarray*}\n",
        "  y_0' \u0026 = \u0026 -k_0 y_0, \\\\\n",
        "  y_1' \u0026 = \u0026 k_0 y_0 - k_1 y_1,\n",
        "\\end{eqnarray*}\n",
        "with each state corresponding to the drug mass in the gut and the central compartment.\n",
        "This system can be solved analytically for initial conditions $(y_0^I, y_1^I)$ at time $t = 0$:\n",
        "\\begin{eqnarray*}\n",
        "  y_0 \u0026 = \u0026 y_0^I e^{-k_0 t}, \\\\\n",
        "  y_1 \u0026 = \u0026 \\frac{e^{-k_1 t}}{k_0 - k_1} \\left [ y_0^I k_0(1 - e^{(k_1 - k_0)t}) + (k_0 - k_1) y^I_1 \\right ], \n",
        "\\end{eqnarray*}\n",
        "provided $k_0 \\neq k_1$.\n",
        "\n",
        "We can also use JAX's `odeint` to solve the equation numerically. This will set us up for more complicated problems. The data are noisy observations of $y_1$ (in practice we should use $y_1 / V$, where $V$ is the volume of the central compartment, but I'll omit this for now).\n"
      ]
    },
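    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check (a sketch with example rate constants, not part of the fitted model), we can verify numerically that the closed-form expressions satisfy the ODE, using a central finite difference:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "k0, k1 = 1.5, 0.25            # example rate constants\n",
        "y0_init, y1_init = 100.0, 0.0  # initial gut and central masses\n",
        "\n",
        "def analytic(t):\n",
        "  # Closed-form solution of the gut / central-compartment system.\n",
        "  y0 = y0_init * np.exp(-k0 * t)\n",
        "  y1 = (np.exp(-k1 * t) / (k0 - k1)\n",
        "        * (y0_init * k0 * (1.0 - np.exp((k1 - k0) * t))\n",
        "           + (k0 - k1) * y1_init))\n",
        "  return y0, y1\n",
        "\n",
        "t, h = 2.0, 1e-6\n",
        "y0_t, y1_t = analytic(t)\n",
        "dy0 = (analytic(t + h)[0] - analytic(t - h)[0]) / (2 * h)\n",
        "dy1 = (analytic(t + h)[1] - analytic(t - h)[1]) / (2 * h)\n",
        "assert abs(dy0 - (-k0 * y0_t)) < 1e-3             # y0' = -k0 y0\n",
        "assert abs(dy1 - (k0 * y0_t - k1 * y1_t)) < 1e-3  # y1' = k0 y0 - k1 y1\n",
        "```"
      ]
    },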
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M8brb74_XvoG"
      },
      "outputs": [],
      "source": [
        "# NOTE: need to pass the initial time as the first element of t.\n",
        "t = np.array([0., 0.5, 0.75, 1, 1.25, 1.5, 2, 3, 4, 5, 6])\n",
        "y0 = np.array([100.0, 0.0])\n",
        "\n",
        "theta = np.array([1.5, 0.25])\n",
        "def system(state, time, theta):\n",
        "  k1 = theta[0]\n",
        "  k2 = theta[1]\n",
        "\n",
        "  return jnp.array([\n",
        "    - k1 * state[0],\n",
        "    k1 * state[0] - k2 * state[1]\n",
        "  ])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IcCdeUYS6m2b"
      },
      "outputs": [],
      "source": [
        "use_analytical_sln = True\n",
        "\n",
        "if use_analytical_sln:\n",
        "  def ode_map(k1, k2):\n",
        "    sln = jnp.exp(- k2 * t) / (k1 - k2) * (y0[0] * k1 * (1 - jnp.exp((k2 - k1) * t)) + (k1 - k2) * y0[1])\n",
        "    return sln[1:]\n",
        "else:\n",
        "  def ode_map(k1, k2):\n",
        "    theta = jnp.array([k1, k2])\n",
        "    return odeint(system, y0, t, theta, mxstep = 1e6)[1:, 1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kK0V0YgaBJ8e"
      },
      "outputs": [],
      "source": [
        "states = ode_map(k1 = theta[0], k2 = theta[1])\n",
        "random.normal(random.PRNGKey(37272710), (states.shape[0],))\n",
        "jnp.log(states)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z_udVycUau-P"
      },
      "source": [
        "## 1.1 Model for one patient receiving a single dose"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lYt8vpauYEv9"
      },
      "outputs": [],
      "source": [
        "# Simulate data\n",
        "states = ode_map(k1 = theta[0], k2 = theta[1])\n",
        "sigma = 0.1\n",
        "log_y = sigma * random.normal(random.PRNGKey(37272710), (states.shape[0],)) \\\n",
        "  + jnp.log(states)\n",
        "\n",
        "y = jnp.exp(log_y)\n",
        "# print(y)\n",
        "\n",
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], states)\n",
        "plot(t[1:], y, 'o')\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3l_LeXPzpBo"
      },
      "source": [
        "### 1.1.1 Run model with TFP\n",
        "The model runs faster on a CPU than a GPU, because of the ODE integrator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v5DB0Yu_Ycy3"
      },
      "outputs": [],
      "source": [
        "model = tfd.JointDistributionSequentialAutoBatched([\n",
        "    # Priors\n",
        "    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = \"k1\"),\n",
        "    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = \"k2\"),\n",
        "    tfd.HalfNormal(scale = 1., name = \"sigma\"),\n",
        "\n",
        "    lambda sigma, k2, k1: (\n",
        "      tfd.LogNormal(loc = jnp.log(ode_map(k1, k2)),\n",
        "                    scale = sigma[..., jnp.newaxis], name = \"y\"))\n",
        "])\n",
        "\n",
        "def target_log_prob_fn(k1, k2, sigma):\n",
        "  return model.log_prob((k1, k2, sigma, y))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_LX3pepcZDT4"
      },
      "outputs": [],
      "source": [
        "num_dimensions = 3\n",
        "def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "  prior_location = jnp.log(jnp.array([1., 1., 1.]))\n",
        "  prior_scale = jnp.array([0.5, 0.5, 0.5])\n",
        "  return jnp.exp(prior_scale * random.normal(key, shape + (num_dimensions,)) + prior_location)\n",
        "\n",
        "# initial_state = initialize((4, ), key = random.PRNGKey(1954))\n",
        "initial_state = model.sample(sample_shape = (4, 1), seed = random.PRNGKey(1954))[:3]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cPCYbGWOCYh1"
      },
      "outputs": [],
      "source": [
        "x = jnp.array(initial_state).reshape(3, 4)\n",
        "print(x[0, :])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WjbHd99uZuqg"
      },
      "outputs": [],
      "source": [
        "# TODO: find a way to do this when the init is a list!\n",
        "# Check call to target_log_prob_fn works\n",
        "# target = target_log_prob_fn(initial_state)\n",
        "# print(target)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U4Kj2OBTqbKe"
      },
      "outputs": [],
      "source": [
        "# Prior predictive checks\n",
        "num_prior_samples = 1000\n",
        "*prior_samples, prior_predictive = model.sample(num_prior_samples, seed = random.PRNGKey(37272709))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krd1-QxmiB4K"
      },
      "outputs": [],
      "source": [
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], y, 'o')\n",
        "plot(t[1:], np.median(prior_predictive, axis = 0), color = 'yellow')\n",
        "plot(t[1:], np.quantile(prior_predictive, q = 0.95, axis = 0), linestyle = ':', color = 'yellow')\n",
        "plot(t[1:], np.quantile(prior_predictive, q = 0.05, axis = 0), linestyle = ':', color = 'yellow')\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sqd_nhqKo8yn"
      },
      "outputs": [],
      "source": [
        "# Implement ChEES transition kernel.\n",
        "init_step_size = 1\n",
        "warmup_length = 1000\n",
        "\n",
        "kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)\n",
        "kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)\n",
        "kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "     kernel, warmup_length, target_accept_prob = 0.75,\n",
        "     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n",
        "\n",
        "def trace_fn(current_state, pkr):\n",
        "  return (\n",
        "    # proxy for divergent transitions\n",
        "    get_innermost(pkr, 'log_accept_ratio') \u003c -1000\n",
        "  )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6gdbPSgLpEuw"
      },
      "outputs": [],
      "source": [
        "num_chains = 4\n",
        "\n",
        "mcmc_states, diverged = tfp.mcmc.sample_chain(\n",
        "    num_results = 2000,\n",
        "    current_state = initial_state, \n",
        "    kernel = kernel,\n",
        "    trace_fn = trace_fn,\n",
        "    seed = random.PRNGKey(1954))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krol-v6AEspw"
      },
      "outputs": [],
      "source": [
        "# remove warmup samples\n",
        "for i in range(0, len(mcmc_states)):\n",
        "  mcmc_states[i] = mcmc_states[i][warmup_length:]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pgAXKnQ_TgXQ"
      },
      "outputs": [],
      "source": [
        "# get draws for posterior predictive checks\n",
        "*_, posterior_predictive = model.sample(value = mcmc_states, \n",
        "                                        seed = random.PRNGKey(37272709))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_mjLAI5RvRNp"
      },
      "source": [
        "### 1.1.2 Analyze results\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YEmtNcGERhRu"
      },
      "outputs": [],
      "source": [
        "print(\"Divergent transition(s):\", np.sum(diverged[1000:]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GAFJ-4NK_Xwj"
      },
      "source": [
        "To convert TFP's output to something compatible with ArviZ, we'll follow the example in https://jeffpollock9.github.io/bayesian-workflow-with-tfp-and-arviz/."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fYwytRbZONKy"
      },
      "outputs": [],
      "source": [
        "parameter_names = model._flat_resolve_names()\n",
        "\n",
        "az_states = az.from_dict(\n",
        "    prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},\n",
        "    posterior={\n",
        "        k: np.swapaxes(v, 0, 1) for k, v in zip(parameter_names, mcmc_states)\n",
        "    },\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "21J3YlDHQtrm"
      },
      "outputs": [],
      "source": [
        "print(az.summary(az_states).filter(items=[\"mean\", \"sd\", \"mcse_sd\", \"hdi_3%\", \n",
        "                                       \"hdi_97%\", \"ess_bulk\", \"ess_tail\", \n",
        "                                       \"r_hat\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nPzebaMXSBTO"
      },
      "outputs": [],
      "source": [
        "axs = az.plot_trace(az_states, combined = False, compact = False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "teFRHkjhSUNu"
      },
      "outputs": [],
      "source": [
        "# TODO: include potential divergent transitions.\n",
        "az.plot_pair(az_states, figsize = (6, 6), kind = 'hexbin', divergences = True);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e0wmo3pOUTtp"
      },
      "outputs": [],
      "source": [
        "ppc_data = posterior_predictive.reshape((4000, 10))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yBfRKltT8fD2"
      },
      "outputs": [],
      "source": [
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], y, 'o')\n",
        "plot(t[1:], np.median(ppc_data, axis = 0), color = 'yellow')\n",
        "plot(t[1:], np.quantile(ppc_data, q = 0.95, axis = 0), linestyle = ':', color = 'yellow')\n",
        "plot(t[1:], np.quantile(ppc_data, q = 0.05, axis = 0), linestyle = ':', color = 'yellow')\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xeKKKn7IDYN5"
      },
      "source": [
        "## 1.2 Clinical event schedule\n",
        "\n",
        "Let's now suppose the patient receives a bolus dose every $12$ hours for a total of $15$ doses.\n",
        "The first dose is administered at time $t = 0$ and the final dose at time $t = 180$ (hours).\n",
        "We take many observations during the first, second, and fourteenth doses. For all other dosing events, we record the drug plasma concentration at the time of the dosing event (i.e. right before the dosing), and then 6 and 12 hours after the dose is administered.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zvL6F0_sELFw"
      },
      "outputs": [],
      "source": [
        "# Construct event times, and identify dosing times (all other times correspond\n",
        "# to measurement events).\n",
        "time_after_dose = np.array([0.083, 0.167, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8])\n",
        "\n",
        "t = np.concatenate([\n",
        "    [0.], time_after_dose,\n",
        "    [12.], time_after_dose + 12,\n",
        "    np.linspace(start = 24, stop = 156, num = 12),\n",
        "    [168.], 168. + time_after_dose,\n",
        "    [180., 192.]])\n",
        "\n",
        "\n",
        "start_event = np.array([], dtype = int)\n",
        "dosing_time = range(0, 192, 12)\n",
        "\n",
        "# Use dosing events to determine times of integration between\n",
        "# exterior interventions on the system.\n",
        "eps = 1e-4  # hack to deal with some t being slightly offset.\n",
        "for t_dose in dosing_time:\n",
        "  start_event = np.append(start_event, np.where(abs(t - t_dose) \u003c= eps))\n",
        "\n",
        "amt = jnp.array([1000., 0.])\n",
        "n_dose = start_event.shape[0]\n",
        "\n",
        "start_event = np.append(start_event, t.shape[0] - 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-bwrJoy8mRmf"
      },
      "outputs": [],
      "source": [
        "def ode_map (theta, dt, current_state):\n",
        "  k1 = theta[0]\n",
        "  k2 = theta[1]\n",
        "  y0_hat = jnp.exp(- k1 * dt) * current_state[0]\n",
        "  y1_hat = jnp.exp(- k2 * dt) / (k1 - k2) * (current_state[0] * k1 *\\\n",
        "                (1 - jnp.exp((k2 - k1) * dt)) + (k1 - k2) * current_state[1])\n",
        "  return jnp.array([y0_hat, y1_hat])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A5vxTyMDnl__"
      },
      "outputs": [],
      "source": [
        "ode_map(theta, np.array([1, 2, 3]), y0)[1, :]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zgh3JuZda-d4"
      },
      "source": [
        "We now wrap our ODE solver (whether via the analytical solution or a numerical integrator) inside an event schedule handler. For starters, we'll go through the events using a `for` loop. This, it turns out, is fairly inefficient, and we'll later revise this code using `jax.lax.scan`."
      ]
    },
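    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The handler's logic — carry the compartment state across events, emit the central-compartment mass at each one, and add any bolus dose to the carried state — can be sketched in plain NumPy (a toy analogue of the scan version below; the dose amounts, times, and rate constants here are example values):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def ode_map_np(theta, dt, state):\n",
        "  # Analytic solution over a time step dt, as in ode_map above.\n",
        "  k1, k2 = theta\n",
        "  y0 = np.exp(-k1 * dt) * state[0]\n",
        "  y1 = (np.exp(-k2 * dt) / (k1 - k2)\n",
        "        * (state[0] * k1 * (1 - np.exp((k2 - k1) * dt))\n",
        "           + (k1 - k2) * state[1]))\n",
        "  return np.array([y0, y1])\n",
        "\n",
        "theta = (1.5, 0.25)\n",
        "times = np.array([0., 6., 12., 18., 24.])\n",
        "# Bolus amounts administered at each event time; the t = 0 dose is\n",
        "# folded into the initial state, so its slot here is 0.\n",
        "amt_vec = np.array([0., 0., 1000., 0., 0.])\n",
        "\n",
        "state = np.array([1000., 0.])   # first dose initialises the gut\n",
        "y_hat = []\n",
        "for i in range(1, len(times)):\n",
        "  state = ode_map_np(theta, times[i] - times[i - 1], state)\n",
        "  y_hat.append(state[1])                      # emit central compartment\n",
        "  state = state + np.array([amt_vec[i], 0.])  # apply any bolus dose\n",
        "```"
      ]
    },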
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xPM7h8KllY50"
      },
      "outputs": [],
      "source": [
        "def ode_map_event (theta):\n",
        "  '''\n",
        "  Wrapper around the ODE solver, based on the event schedule.\n",
        "  NOTE: if using the ode integrator, need to adjust the shape of mass.\n",
        "  '''\n",
        "  y_hat = jnp.array([])\n",
        "  current_state = amt\n",
        "  for i in range(0, n_dose):\n",
        "    t_integration = jax.lax.dynamic_slice(t, (start_event[i], ), \n",
        "                           (start_event[i + 1] - start_event[i] + 1, ))\n",
        "    \n",
        "    mass = ode_map(theta, t_integration - t_integration[0], current_state)\n",
        "    # mass = odeint(system, current_state, t_integration,\n",
        "    #               theta, rtol = 1e-6, atol = 1e-6, mxstep = 1e3)\n",
        "\n",
        "    y_hat = jnp.append(y_hat, mass[1, 1:])\n",
        "    current_state = mass[:, -1] + amt\n",
        "  return y_hat\n",
        "\n",
        "y_hat = ode_map_event(theta)\n",
        "log_y_hat = jnp.log(y_hat[1:])\n",
        "\n",
        "sigma = 0.5\n",
        "# NOTE: no observation at time t = 0.\n",
        "log_y = sigma * random.normal(random.PRNGKey(1954), (y_hat.shape[0],)) \\\n",
        "  + jnp.log(y_hat)\n",
        "y_obs = jnp.exp(log_y)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fEYhXQwqbo8h"
      },
      "outputs": [],
      "source": [
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], y_hat)\n",
        "plot(t[1:], y_obs, 'o', markersize = 2)\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l4SpOYIEbvVv"
      },
      "source": [
        "The code above works fine to simulate data but we can do better using `jax.lax.scan`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZCdOhAupsRjd"
      },
      "outputs": [],
      "source": [
        "t_jax = jnp.array(t)\n",
        "amt_vec = np.repeat(0., t.shape[0])\n",
        "amt_vec[start_event] = 1000\n",
        "amt_vec[-1] = 0.\n",
        "amt_vec_jax = jnp.array(amt_vec)\n",
        "\n",
        "# Overwrite definition of ode_map_event.\n",
        "def ode_map_event(theta):\n",
        "  def ode_map_step (current_state, event_index):\n",
        "    dt = t_jax[event_index] - t_jax[event_index - 1]\n",
        "    y_sln = ode_map(theta, dt, current_state)\n",
        "    return (y_sln + jnp.array([amt_vec_jax[event_index], 0.])), y_sln[1,]\n",
        "\n",
        "  (__, yhat) = jax.lax.scan(ode_map_step, amt, np.array(range(1, t.shape[0])))\n",
        "  return yhat\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8zp45oXI2OKY"
      },
      "outputs": [],
      "source": [
        "y_hat = ode_map_event(theta) \n",
        "\n",
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], y_hat)\n",
        "plot(t[1:], y_obs, 'o', markersize = 2)\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Obax1lexpON7"
      },
      "outputs": [],
      "source": [
        "# Remark: using more informative priors helps ensure the chains mix\n",
        "# reasonably well. (Could be interesting to examine with nested-rhat\n",
        "# the case where they don't).\n",
        "model = tfd.JointDistributionSequentialAutoBatched([\n",
        "    # Priors\n",
        "    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = \"k1\"),\n",
        "    tfd.LogNormal(loc = jnp.log(.5), scale = 0.25, name = \"k2\"),\n",
        "    tfd.HalfNormal(scale = 1., name = \"sigma\"),\n",
        "\n",
        "    lambda sigma, k2, k1: (\n",
        "      tfd.LogNormal(loc = jnp.log(ode_map_event(jnp.array([k1, k2]))),\n",
        "                    scale = sigma[..., jnp.newaxis], name = \"y_obs\"))\n",
        "])\n",
        "\n",
        "def target_log_prob_fn(k1, k2, sigma):\n",
        "  return model.log_prob((k1, k2, sigma, y_obs))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rm-wSAdaDJ0B"
      },
      "outputs": [],
      "source": [
        "initial_state = model.sample(sample_shape = (4, 1), seed = random.PRNGKey(1954))[:3]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K11arQRyqdYV"
      },
      "outputs": [],
      "source": [
        "# TODO: find a way to test target_log_prob_fn with init as a list\n",
        "# print(initial_state)\n",
        "# target_log_prob_fn(initial_state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gDQfSMebpv2G"
      },
      "outputs": [],
      "source": [
        "# Implement ChEES transition kernel.\n",
        "init_step_size = 0.1\n",
        "warmup_length = 1000\n",
        "\n",
        "kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)\n",
        "kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)\n",
        "kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "     kernel, warmup_length, target_accept_prob = 0.75,\n",
        "     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ARqxRVdI_tzm"
      },
      "outputs": [],
      "source": [
        "def trace_fn(current_state, pkr):\n",
        "  return (\n",
        "    # proxy for divergent transitions\n",
        "    get_innermost(pkr, 'log_accept_ratio') \u003c -1000,\n",
        "    get_innermost(pkr, 'step_size'),\n",
        "    get_innermost(pkr, 'max_trajectory_length')\n",
        "  )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5TB2NpYjqKYZ"
      },
      "outputs": [],
      "source": [
        "num_chains = 4\n",
        "\n",
        "mcmc_states, diverged = tfp.mcmc.sample_chain(\n",
        "    num_results = 2000, \n",
        "    current_state = initial_state, \n",
        "    kernel = kernel,\n",
        "    trace_fn = trace_fn,\n",
        "    seed = random.PRNGKey(1954))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ayEH0FkvAqk1"
      },
      "outputs": [],
      "source": [
        "semilogy(diverged[1], label = \"step size\")\n",
        "semilogy(diverged[2], label = \"max trajectory length\")\n",
        "legend(loc = \"best\")\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ewxRhNdRKrfR"
      },
      "outputs": [],
      "source": [
        "# remove warmup samples\n",
        "for i in range(0, len(mcmc_states)):\n",
        "  mcmc_states[i] = mcmc_states[i][warmup_length:]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N1YEdr-fEnuA"
      },
      "source": [
        "We'll only look at some essential diagnostics. For more, we can follow the code in the single-dose example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-TCIf1yJELrK"
      },
      "outputs": [],
      "source": [
        "print(\"Divergent transition(s):\", np.sum(diverged[0][1000:]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NKrEpAZRT8Y4"
      },
      "outputs": [],
      "source": [
        "parameter_names = model._flat_resolve_names()\n",
        "\n",
        "az_states = az.from_dict(\n",
        "    prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},\n",
        "    posterior={\n",
        "        k: np.swapaxes(v, 0, 1) for k, v in zip(parameter_names, mcmc_states)\n",
        "    },\n",
        ")\n",
        "\n",
        "print(az.summary(az_states).filter(items=[\"mean\", \"sd\", \"mcse_sd\", \"hdi_3%\", \n",
        "                                       \"hdi_97%\", \"ess_bulk\", \"ess_tail\", \n",
        "                                       \"r_hat\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vFcX2gsjdQNS"
      },
      "outputs": [],
      "source": [
        "# get draws for posterior predictive checks\n",
        "*_, posterior_predictive = model.sample(value = mcmc_states, \n",
        "                                        seed = random.PRNGKey(37272709))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PB9XuVUDP-qv"
      },
      "outputs": [],
      "source": [
        "# ppc_data = posterior_predictive.reshape(1000, 4, 52)\n",
        "\n",
        "# az_data = az.from_dict(\n",
        "#     posterior = dict(x = ppc_data.transpose((1, 0, 2)))\n",
        "# )\n",
        "# print(az.summary(az_data).filter(items=[\"mean\", \"hdi_3%\", \n",
        "#                                        \"hdi_97%\", \"ess_bulk\", \"ess_tail\", \n",
        "#                                        \"r_hat\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RgKRN9XnFBQo"
      },
      "outputs": [],
      "source": [
        "# REMARK: hmmm... the ppc's look odd. Not sure why. Everything else looks fine.\n",
        "figure(figsize = [6, 6])\n",
        "semilogy(t[1:], y_obs, 'o')\n",
        "semilogy(t[1:], np.median(posterior_predictive, axis = (0, 1, 2)), color = 'yellow')\n",
        "semilogy(t[1:], np.quantile(posterior_predictive, q = 0.95, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')\n",
        "semilogy(t[1:], np.quantile(posterior_predictive, q = 0.05, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3u-z8HIDdnIN"
      },
      "source": [
        "## 1.3 Population models\n",
        "\n",
        "We now model data from multiple patients and use a hierarchical model to describe inter-individual heterogeneity. For simplicity, we assume the patients all undergo the same treatment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-q70FSapWt5"
      },
      "source": [
        "### 1.3.1 Simulate data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B5u_7uLf4XM5"
      },
      "outputs": [],
      "source": [
        "# (Code from previous cells, rewritten here to make\n",
        "# section 1.3 self-contained).\n",
        "# TODO: replace this with a function.\n",
        "time_after_dose = np.array([0.083, 0.167, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8])\n",
        "\n",
        "t = np.append(\n",
        "    np.append(np.append(np.append(0., time_after_dose),\n",
        "                          np.append(12., time_after_dose + 12)),\n",
        "               np.linspace(start = 24, stop = 156, num = 12)),\n",
        "               np.append(jnp.append(168., 168. + time_after_dose),\n",
        "               np.array([180, 192])))\n",
        "\n",
        "start_event = np.array([], dtype = int)\n",
        "dosing_time = range(0, 192, 12)\n",
        "\n",
        "# Use dosing events to determine times of integration between\n",
        "# exterior interventions on the system.\n",
        "eps = 1e-4  # hack to deal with some t being slightly offset.\n",
        "for t_dose in dosing_time:\n",
        "  start_event = np.append(start_event, np.where(abs(t - t_dose) \u003c= eps))\n",
        "\n",
        "amt = jnp.array([1000., 0.])\n",
        "n_dose = start_event.shape[0]\n",
        "\n",
        "start_event = np.append(start_event, t.shape[0] - 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mfRo0Ummukam"
      },
      "outputs": [],
      "source": [
        "# NOTE: need to run the first cell under Section 1.2\n",
        "# (Clinical event schedule)\n",
        "\n",
        "n_patients = 100\n",
        "pop_location = jnp.log(jnp.array([1.5, 0.25]))\n",
        "# pop_location = jnp.log(jnp.array([0.5, 1.0]))\n",
        "pop_scale = jnp.array([0.15, 0.35])\n",
        "theta_patient = jnp.exp(pop_scale * random.normal(random.PRNGKey(37272709), \n",
        "                          (n_patients, ) + (2,)) + pop_location)\n",
        "\n",
        "amt = np.array([1000., 0.])\n",
        "amt_patient = np.append(np.repeat(amt[0], n_patients),\n",
        "                        np.repeat(amt[1], n_patients))\n",
        "amt_patient = amt_patient.reshape(2, n_patients)\n",
        "\n",
        "# redfine variables from previous section (in case we only run population model)\n",
        "t_jax = jnp.array(t)\n",
        "amt_vec = np.repeat(0., t.shape[0])\n",
        "amt_vec[start_event] = 1000\n",
        "amt_vec[amt_vec.shape[0] - 1] = 0.\n",
        "amt_vec_jax = jnp.array(amt_vec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4s06cfWJuxBS"
      },
      "source": [
        "We rewrite the ode_map, so that, rather than returning the mass for one patient, it returns the mass across multiple patients. The function `ode_map` now takes in the physiological parameters for all patients, as well as the initial states for each patient.\n",
        "\n",
        "The call to `jax.lax.scan` now takes in an additional argument, `unroll`, which is used to unroll the for loop and make its call on an accelerator more efficient. By default, `unroll = 1` (no unrolling); we observe a major speedup when using `unroll = 10`, and an additional (more minor) speedup when `unroll = 20`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7RwFyOggg_dr"
      },
      "outputs": [],
      "source": [
        "# Rewrite ode_map_event for population case.\n",
        "# TODO: remove 'use_second_axis' hack.\n",
        "def ode_map (theta, dt, current_state, use_second_axis = False):\n",
        "  if (use_second_axis):\n",
        "    k1 = theta[0, :]\n",
        "    k2 = theta[1, :]\n",
        "  else: \n",
        "    k1 = theta[:, 0]\n",
        "    k2 = theta[:, 1]\n",
        "\n",
        "  y0_hat = jnp.exp(- k1 * dt) * current_state[0, :]\n",
        "  y1_hat = jnp.exp(- k2 * dt) / (k1 - k2) * (current_state[0, :] * k1 *\\\n",
        "                (1 - jnp.exp((k2 - k1) * dt)) + (k1 - k2) * current_state[1, :])\n",
        "  return jnp.array([y0_hat, y1_hat])\n",
        "\n",
        "# @jax.jit  # Cannot use jit if function has an IF statement.\n",
        "def ode_map_event(theta, use_second_axis = False):\n",
        "  def ode_map_step (current_state, event_index):\n",
        "    dt = t_jax[event_index] - t_jax[event_index - 1]\n",
        "    y_sln = ode_map(theta, dt, current_state, use_second_axis)\n",
        "    dose = jnp.repeat(amt_vec_jax[event_index], n_patients)\n",
        "    y_after_dose = y_sln + jnp.append(jnp.repeat(amt_vec_jax[event_index], n_patients),\n",
        "                                      jnp.repeat(0., n_patients)).reshape(2, n_patients)\n",
        "    return (y_after_dose, y_sln[1, ])\n",
        "\n",
        "  (__, yhat) = jax.lax.scan(ode_map_step, amt_patient, \n",
        "                            np.array(range(1, t.shape[0])),\n",
        "                            unroll = 20)\n",
        "  return yhat"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qdAE6KxseIY5"
      },
      "outputs": [],
      "source": [
        "# Simulate some data\n",
        "y_hat = ode_map_event(theta_patient)\n",
        "\n",
        "sigma = 0.1\n",
        "# NOTE: no observation at time t = 0.\n",
        "log_y = sigma * random.normal(random.PRNGKey(1954), y_hat.shape) \\\n",
        "  + jnp.log(y_hat)\n",
        "y_obs = jnp.exp(log_y)\n",
        "\n",
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], y_hat)\n",
        "plot(t[1:], y_obs, 'o', markersize = 2)\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GcgditOnpfOJ"
      },
      "source": [
        "### 1.3.2 Fit the model with TFP "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ypgnnwm0CxKl"
      },
      "source": [
        "This is an adaptation of the previous model, except we're now only working with parameters on the unconstrained scale. This makes it easier for HMC and it is good practice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P_N-3eRl6HNa"
      },
      "outputs": [],
      "source": [
        "pop_model = tfd.JointDistributionSequentialAutoBatched([\n",
        "    # tfd.LogNormal(loc = jnp.log(1.), scale = 0.25, name = \"k1_pop\"),\n",
        "    # tfd.LogNormal(loc = jnp.log(0.3), scale = 0.1, name = \"k2_pop\"),\n",
        "    # tfd.Normal(loc = jnp.log(1.), scale = 0.25, name = \"log_k1_pop\"),\n",
        "    tfd.Normal(loc = jnp.log(1.), scale = 0.1, name = \"log_k1_pop\"),\n",
        "    tfd.Normal(loc = jnp.log(0.3), scale = 0.1, name = \"log_k2_pop\"),\n",
        "    tfd.Normal(loc = jnp.log(0.15), scale = 0.1, name = \"log_scale_k1\"),\n",
        "    tfd.Normal(loc = jnp.log(0.35), scale = 0.1, name = \"log_scale_k2\"),\n",
        "    # tfd.HalfNormal(scale = 1., name = \"sigma\"),\n",
        "    tfd.Normal(loc = -1., scale = 1., name = \"log_sigma\"),\n",
        "\n",
        "    # non-centered parameterization for hierarchy\n",
        "    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),\n",
        "                               scale = jnp.ones(n_patients),\n",
        "                               name = \"eta_k1\"),\n",
        "                    reinterpreted_batch_ndims = 1),\n",
        "    \n",
        "    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),\n",
        "                               scale = jnp.ones(n_patients),\n",
        "                               name = \"eta_k2\"),\n",
        "                    reinterpreted_batch_ndims = 1),\n",
        "\n",
        "    lambda eta_k2, eta_k1, log_sigma, log_scale_k2, log_scale_k1,\n",
        "           log_k2_pop, log_k1_pop: (\n",
        "      tfd.Independent(tfd.LogNormal(\n",
        "          loc = jnp.log(\n",
        "              ode_map_event(theta = jnp.array([\n",
        "                  jnp.exp(log_k1_pop[..., jnp.newaxis] + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),\n",
        "                  jnp.exp(log_k2_pop[..., jnp.newaxis] + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),\n",
        "                  use_second_axis = True)),\n",
        "          scale = jnp.exp(log_sigma[..., jnp.newaxis]), name = \"y_obs\")))\n",
        "\n",
        "    # lambda eta_k2, eta_k1, sigma, log_scale_k2, log_scale_k1,\n",
        "    #        k2_pop, k1_pop: (\n",
        "    #   tfd.Independent(tfd.LogNormal(\n",
        "    #       loc = jnp.log(\n",
        "    #           ode_map_event(theta = jnp.array(\n",
        "    #           [jnp.exp(jnp.log(k1_pop[..., jnp.newaxis]) + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),\n",
        "    #            jnp.exp(jnp.log(k2_pop[..., jnp.newaxis]) + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),\n",
        "    #            use_second_axis = True)),\n",
        "    #       scale = sigma[..., jnp.newaxis], name = \"y_obs\")))\n",
        "])\n",
        "\n",
        "def pop_target_log_prob_fn(log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,\n",
        "                           log_sigma, eta_k1, eta_k2):\n",
        "  return pop_model.log_prob((log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,\n",
        "                            log_sigma, eta_k1, eta_k2, y_obs))\n",
        "  # CHECK -- do we need to parenthesis?\n",
        "\n",
        "\n",
        "\n",
        "# def pop_target_log_prob_fn(k1_pop, k2_pop, log_scale_k1, log_scale_k2,\n",
        "#                            sigma, eta_k1, eta_k2):\n",
        "#   return pop_model.log_prob((k1_pop, k2_pop, log_scale_k1, log_scale_k2,\n",
        "#                            sigma, eta_k1, eta_k2, y_obs))\n",
        "\n",
        "def pop_target_log_prob_fn_flat(x):\n",
        "  k1_pop = x[:, 0]\n",
        "  k2_pop = x[:, 1]\n",
        "  sigma = x[:, 2]\n",
        "  log_scale_k1 = x[:, 3]\n",
        "  log_scale_k2 = x[:, 4]\n",
        "  eta_k1 = x[:, 5:(5 + n_patients)]\n",
        "  eta_k2 = x[:, (5 + n_patients):(5 + 2 * n_patients)]\n",
        "\n",
        "  return pop_model.log_prob((k1_pop, k2_pop, log_scale_k1, log_scale_k2,\n",
        "                           sigma, eta_k1, eta_k2, y_obs))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7x630Zb1DHuF"
      },
      "source": [
        "If we want to run many chains in parallel and use $n\\hat R$ (nested $\\hat R$), we need to specify the number of chains and the number of super chains.\n",
        "The number of super chains determined the numbers of distinct starting point, seeing within each super chain, each chain starts at the same location. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-R1KuCCyXDVY"
      },
      "outputs": [],
      "source": [
        "# Sample initial states from prior\n",
        "num_chains = 128\n",
        "num_super_chains = 4  #  num_chains  #  128\n",
        "\n",
        "n_parm = 5 + 2 * n_patients\n",
        "initial_state_raw = pop_model.sample(sample_shape = (num_super_chains, 1),\\\n",
        "                                     seed = random.PRNGKey(37272710))[:7]\n",
        "\n",
        "# QUESTION: does this assignment create a pointer?\n",
        "initial_state = initial_state_raw\n",
        "\n",
        "for i in range(0, len(initial_state_raw)):\n",
        "  initial_state[i] = np.repeat(initial_state_raw[i],\n",
        "                               num_chains // num_super_chains, axis = 0)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0J3pKThK1yE5"
      },
      "source": [
        "Some care is required when setting the tuning parameters for ChEES-HMC, in particular the initial step size. In the [ChEES-HMC paper](http://proceedings.mlr.press/v130/hoffman21a/hoffman21a.pdf), the following proceudre is used: \"Initial step sizes were chosen by repeatedly halving the step size (starting from a consistently too-large value of 1.0) until an HMC proposal with a single leapfrog step achieved a harmonic-mean acceptance probability of at least 0.5.\"\n",
        "\n",
        "TODO: implement this.\n",
        "\n",
        "For now we note that an \"appropriate\" initial step size depends on the number chains (for reasons I don't quite understand...)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aXS4RkqKr_v8"
      },
      "outputs": [],
      "source": [
        "# Implement ChEES transition kernel. Increase the target acceptance rate\n",
        "# to avoid divergent transitions.\n",
        "# NOTE: increasing the target acceptance probability can lead to poor performance.\n",
        "init_step_size = 0.001  # CHECK -- how to best tune this?\n",
        "warmup_length = 1000 # 1000\n",
        "\n",
        "kernel = tfp.mcmc.HamiltonianMonteCarlo(pop_target_log_prob_fn, \n",
        "                                        step_size = init_step_size, \n",
        "                                        num_leapfrog_steps = 10)\n",
        "kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)\n",
        "kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "     kernel, warmup_length, target_accept_prob = 0.75,\n",
        "     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n",
        "\n",
        "def trace_fn(current_state, pkr):\n",
        "  return (\n",
        "    # proxy for divergent transitions\n",
        "    get_innermost(pkr, 'log_accept_ratio') \u003c -1000,\n",
        "    get_innermost(pkr, 'step_size'),\n",
        "    get_innermost(pkr, 'max_trajectory_length')\n",
        "  )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2hEuLZccsdY4"
      },
      "outputs": [],
      "source": [
        "mcmc_states, diverged = tfp.mcmc.sample_chain(\n",
        "    num_results = warmup_length + 1000,\n",
        "    current_state = initial_state,\n",
        "    kernel = kernel,\n",
        "    trace_fn = trace_fn,\n",
        "    seed = random.PRNGKey(1954))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s_t77idfg9SW"
      },
      "outputs": [],
      "source": [
        "# Remark: somehow modifying mcmc_states still modifies\n",
        "# mcmc_states_raw.\n",
        "mcmc_states_raw = mcmc_states"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E1CCgq0jppah"
      },
      "source": [
        "### 1.3.3 Traditional diagnostics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ura3AS9Nv3EB"
      },
      "outputs": [],
      "source": [
        "# remove warmup samples\n",
        "# NOTE: not a good idea. It's better to store all the states.\n",
        "if False:\n",
        "  for i in range(0, len(mcmc_states)):\n",
        "    mcmc_states[i] = mcmc_states_raw[i][warmup_length:]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Alb3e0QCd-b"
      },
      "outputs": [],
      "source": [
        "semilogy(diverged[1], label = \"step size\")\n",
        "semilogy(diverged[2], label = \"max_trajectory length\")\n",
        "legend(loc = \"best\")\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JgsVYXdSn0XE"
      },
      "outputs": [],
      "source": [
        "mcmc_states[0].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6F-lSkMjd86x"
      },
      "outputs": [],
      "source": [
        "# Use this to search for points where the the step size changes\n",
        "# dramatically and divergences that might be happening there.\n",
        "if False:\n",
        "  index_l = 219  # 225\n",
        "  index_u = index_l + 1  # 235\n",
        "  print(\"Max L:\" , diverged[2][index_l:index_u])\n",
        "  print(\"Divergence:\", np.sum(diverged[0][index_l:index_u]),\n",
        "        \"at\", np.where(diverged[0][index_l:index_u] == 1))\n",
        "\n",
        "  chain = 0\n",
        "  eta1_state = mcmc_states[5][index_l, chain, :] *\\\n",
        "    mcmc_states[2][index_l, chain, 0] + mcmc_states[0][index_l, chain, 0]\n",
        "  eta2_state = mcmc_states[6][index_l, chain, :] *\\\n",
        "    mcmc_states[3][index_l, chain, 0] + mcmc_states[1][index_l, chain, 0]\n",
        "\n",
        "  k0_state = np.exp(eta1_state)\n",
        "  k1_state = np.exp(eta2_state) \n",
        "  print(k0_state - k1_state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ju1sBMuVX6Nn"
      },
      "outputs": [],
      "source": [
        "print(\"Divergent transition(s):\", np.sum(diverged[0][warmup_length:]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NCAQdcFQYPz9"
      },
      "outputs": [],
      "source": [
        "# NOTE: the last parameter is an 'x': not sure where this comes from...\n",
        "parameter_names = pop_model._flat_resolve_names()[:-1]\n",
        "\n",
        "az_states = az.from_dict(\n",
        "    #prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},\n",
        "    posterior={\n",
        "        k: np.swapaxes(v, 0, 1) for k, v in zip(parameter_names, mcmc_states)\n",
        "    },\n",
        ")\n",
        "\n",
        "print(az.summary(az_states).filter(items=[\"mean\", \"sd\", \"mcse_sd\", \"hdi_3%\", \n",
        "                                       \"hdi_97%\", \"ess_bulk\", \"ess_tail\", \n",
        "                                       \"r_hat\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PwRaHRUIZBKU"
      },
      "outputs": [],
      "source": [
        "# Only plot the population parameters.\n",
        "axs = az.plot_trace(az_states, combined = False, compact = False,\n",
        "                    var_names = parameter_names[:5])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Edxek3KkHgr0"
      },
      "outputs": [],
      "source": [
        "# posterior predictive checks\n",
        "# NOTE: for 100 patients, running this exhausts memory\n",
        "*_, posterior_predictive = pop_model.sample(value = mcmc_states, \n",
        "                                        seed = random.PRNGKey(37272709))\n",
        "ppc_data = posterior_predictive.reshape(1000 * num_chains, 52, n_patients)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kobBLgFmIHHR"
      },
      "outputs": [],
      "source": [
        "# NOTE: unclear why the confidence interval is so small...\n",
        "fig, axes = subplots(n_patients, 1, figsize=(8, 4 * n_patients))\n",
        "\n",
        "for i in range(0, n_patients):\n",
        "  patient_ppc = posterior_predictive[:, :, :, :, i]\n",
        "  axes[i].semilogy(t[1:], y_obs[:, i], 'o')\n",
        "  axes[i].semilogy(t[1:], np.median(patient_ppc, axis = (0, 1, 2)), color = 'yellow')\n",
        "  axes[i].semilogy(t[1:], np.quantile(patient_ppc, q = 0.95, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')\n",
        "  axes[i].semilogy(t[1:], np.quantile(patient_ppc, q = 0.05, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0EpU8FfApLKM"
      },
      "source": [
        "### 1.3.3 Diagnostic using $n \\hat R$."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4lbjEnANvPsK"
      },
      "source": [
        "For starters, let's examine estimates in the short regime, i.e. using only the first few iterations from each chain. We'll focus on $\\log k_{1,\\text{pop}}$ which seems to have the most difficult expectation value to estimate (given it's relatively low ESS)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4KS_r5r6vOqb"
      },
      "outputs": [],
      "source": [
        "# Assumes mcmc_states contains all the samples (including warmup)\n",
        "parameter_index = 0\n",
        "num_samples = 500\n",
        "mc_mean = np.mean(mcmc_states[parameter_index][\n",
        "                  warmup_length:(warmup_length + num_samples), :, :])\n",
        "\n",
        "print(\"Mean:\", mc_mean)\n",
        "print(\"Estimated squared error:\",\n",
        "      np.square(mc_mean -\n",
        "                np.mean(mcmc_states[parameter_index][warmup_length:, :, :])))\n",
        "print(\"Upper bound on expected squared error for one iteration:\",\n",
        "      np.var(mcmc_states[0]) / num_chains)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lQajSPNT7ggR"
      },
      "outputs": [],
      "source": [
        "nRhat = nested_rhat(result_state = mcmc_states, \n",
        "                    num_super_chains = num_super_chains,\n",
        "                    index_param = parameter_index, \n",
        "                    num_samples = num_samples,\n",
        "                    warmup_length = warmup_length)\n",
        "\n",
        "print(\"num_samples:\", num_samples)\n",
        "print(\"nRhat:\", nRhat)\n",
        "print(\"Rhat:\",\n",
        "      tfp.mcmc.potential_scale_reduction(\n",
        "          mcmc_states[0][warmup_length:(num_samples + warmup_length), :, :]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eKflNcfnC4l1"
      },
      "source": [
        "## 2 Michaelis-Menten pharmacokinetics (Incomplete)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AlJ32QWJ5y5"
      },
      "source": [
        "Nonlinear PK model with absorption from the gut.\n",
        "\n",
        "\\begin{eqnarray*}\n",
        "  y_0' \u0026 = \u0026 - k_a y_0 \\\\\n",
        "  y_1' \u0026 = \u0026 k_a y_0 - \\frac{V_m C}{K_m + C},\n",
        "\\end{eqnarray*}\n",
        "whwre $C = y_1 / V$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WvvAVKmyD7OJ"
      },
      "outputs": [],
      "source": [
        "t = np.array([0.0, 0.5, 0.75, 1, 1.25, 1.5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])\n",
        "y0 = np.array([100.0, 0.0])\n",
        "theta = np.array([0.5, 27, 10, 14])\n",
        "\n",
        "def system(state, time, theta):\n",
        "  ka = theta[0]\n",
        "  V = theta[1]\n",
        "  Vm = theta[2]\n",
        "  Km = theta[3]\n",
        "  C = state[1] / V\n",
        "\n",
        "  return jnp.array([\n",
        "    - ka * state[0],\n",
        "    ka * state[0] - Vm * C  / (Km + C)            \n",
        "  ])\n",
        "\n",
        "states = odeint(system, y0, t, theta, mxstep = 1000)\n",
        "sigma = 0.5\n",
        "log_y = sigma * random.normal(random.PRNGKey(37272709), (states.shape[0] - 1,)) \\\n",
        "  + jnp.log(states[1:, 1])\n",
        "\n",
        "y = jnp.exp(log_y)\n",
        "\n",
        "figure(figsize = [6, 6])\n",
        "plot(t[1:], states[1:, 1])\n",
        "plot(t[1:], y, 'o');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZJDzcZ7sE6Ec"
      },
      "outputs": [],
      "source": [
        "def ode_map(ka, V, Vm, Km):\n",
        "  theta = jnp.array([ka, V, Vm, Km])\n",
        "  return odeint(system, y0, t, theta, mxstep = 1e3)[1:, 1]\n",
        "\n",
        "model = tfd.JointDistributionSequentialAutoBatched([\n",
        "    # Priors\n",
        "    tfd.LogNormal(loc = jnp.log(1), scale = 0.5, name = \"ka\"),\n",
        "    tfd.LogNormal(loc = jnp.log(35), scale = 0.5, name = \"V\"),\n",
        "    tfd.LogNormal(loc = jnp.log(10), scale = 0.5, name = \"Vm\"),\n",
        "    tfd.LogNormal(loc = jnp.log(2.5), scale = 1, name = \"Km\"),\n",
        "    tfd.HalfNormal(scale = 1., name = \"sigma\"),\n",
        "\n",
        "    # Likelihood (TODO: divide location by volume to get concentration)\n",
        "    lambda sigma, Km, Vm, V, ka: (\n",
        "      tfd.LogNormal(loc = jnp.log(ode_map(ka, V, Vm, Km) / V),\n",
        "                   scale = sigma[..., jnp.newaxis], name = \"y\"))\n",
        "])\n",
        "\n",
        "def target_log_prob_fn(x):\n",
        "  ka = x[:, 0]\n",
        "  V = x[:, 1]\n",
        "  Vm = x[:, 2]\n",
        "  Km = x[:, 3]\n",
        "  sigma = x[:, 4]\n",
        "  return model.log_prob((ka, V, Vm, Km, sigma, y))\n",
        "\n",
        "num_dimensions = 5\n",
        "def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "  prior_location = jnp.log(jnp.array([1.5, 35, 10, 2.5, 0.5]))\n",
        "  prior_scale = jnp.array([3, 0.5, 0.5, 3, 1.])\n",
        "  return jnp.exp(prior_scale * random.normal(key, shape + (num_dimensions,)) + prior_location)\n",
        "\n",
        "initial_state = initialize((4, ), key = random.PRNGKey(1954))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lqbZweOXF0cE"
      },
      "outputs": [],
      "source": [
        "# Test target probability density can be computed\n",
        "target = target_log_prob_fn(initial_state)\n",
        "print(target)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rcrUavhHJAke"
      },
      "outputs": [],
      "source": [
        "# Implement ChEES transition kernel.\n",
        "init_step_size = 1\n",
        "warmup_length = 250\n",
        "\n",
        "kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)\n",
        "kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)\n",
        "kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "     kernel, warmup_length, target_accept_prob = 0.75,\n",
        "     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FyoGwFy9Gsb7"
      },
      "outputs": [],
      "source": [
        "num_chains = 4\n",
        "\n",
        "# NOTE: It takes 29 seconds to run one iteration. So running 500 iterations\n",
        "# would take ~4 hours :(\n",
        "# QUESTION: why does JAX struggle so much to solve this type of problem??\n",
        "result = tfp.mcmc.sample_chain(\n",
        "    num_results = 1, \n",
        "    current_state = initial_state, \n",
        "    kernel = kernel,\n",
        "    seed = random.PRNGKey(1954))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z5WsAlMA0ZBl"
      },
      "source": [
        "## Draft Code\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cwZzVUChAus5"
      },
      "outputs": [],
      "source": [
        "R = 1.62\n",
        "1 / (R * R - 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CUk2qgx5GtLe"
      },
      "outputs": [],
      "source": [
        "a = np.array(range(4, 1024, 4))\n",
        "d = np.repeat(6., len(a))\n",
        "\n",
        "# Two optimization solutions, solving quadratic equations (+ / -)\n",
        "# Remark: + solution gives a negative upper-bound for delta_u\n",
        "alpha_1 = 2 * a + d / 2 - np.sqrt(np.square(2 * a + d / 2) - 2 * a)\n",
        "alpha_2 = a - alpha_1\n",
        "delta_u = (np.square(alpha_1 + d / 2) / (alpha_1 * alpha_2)) / 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yi6HD2znsYIK"
      },
      "outputs": [],
      "source": [
        "eps = 0.01\n",
        "delta = np.square(1 + eps) - 1\n",
        "print(delta)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vyqywrFuq6hH"
      },
      "outputs": [],
      "source": [
        "semilogy(a / d, delta_u)\n",
        "hlines(delta, (a / d)[0], (a / d)[len(a) - 1], linestyles = '--',\n",
        "      label =  \"delta for 1.01 threshold\")\n",
        "xlabel(\"a / d\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KOxZiFUGidI3"
      },
      "outputs": [],
      "source": [
        "semilogy(a / d, alpha_1 / a, label = \"alpha_1\")\n",
        "semilogy(a / d, alpha_2 / a, label = \"alpha_2\")\n",
        "legend(loc = 'best')\n",
        "xlabel(\"a / d\")\n",
        "ylabel(\"alpha\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RxO-KhRocA13"
      },
      "outputs": [],
      "source": [
        "aindex_location = np.where(a / d == 100)\n",
        "print(index_location)\n",
        "print(delta_u[index_location])\n",
        "delta"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-ItlDq-o0dBd"
      },
      "outputs": [],
      "source": [
        "pop_model = tfd.JointDistributionSequentialAutoBatched([\n",
        "    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = \"k1_pop\"),\n",
        "    tfd.LogNormal(loc = jnp.log(.5), scale = 0.25, name = \"k2_pop\"),\n",
        "    tfd.Normal(loc = jnp.log(0.5), scale = 1., name = \"log_scale_k1\"),\n",
        "    tfd.Normal(loc = jnp.log(0.5), scale = 1., name = \"log_scale_k2\"),\n",
        "    tfd.HalfNormal(scale = 1., name = \"sigma\"),\n",
        "\n",
        "    # non-centered parameterization for hierarchy\n",
        "    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),\n",
        "                               scale = jnp.ones(n_patients),\n",
        "                               name = \"eta_k1\"),\n",
        "                    reinterpreted_batch_ndims = 1),\n",
        "    \n",
        "    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),\n",
        "                               scale = jnp.ones(n_patients),\n",
        "                               name = \"eta_k2\"),\n",
        "                    reinterpreted_batch_ndims = 1),\n",
        "\n",
        "    lambda eta_k2, eta_k1, sigma, log_scale_k2, log_scale_k1,\n",
        "           k2_pop, k1_pop: (\n",
        "      tfd.Independent(tfd.LogNormal(\n",
        "          loc = jnp.log(\n",
        "              ode_map_event(theta = jnp.array(\n",
        "              [jnp.exp(jnp.log(k1_pop[..., jnp.newaxis]) + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),\n",
        "               jnp.exp(jnp.log(k2_pop[..., jnp.newaxis]) + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),\n",
        "               use_second_axis = True)),\n",
        "          scale = sigma[..., jnp.newaxis], name = \"y_obs\")))\n",
        "])\n"
      ]
    },
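    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The hierarchy above uses a non-centered parameterization: each patient's rate constant is recovered from a standard-normal offset `eta`, shifted and scaled by the population parameters. As a minimal sketch (plain NumPy with made-up values; the model applies the same transform with `jnp` inside the call to `ode_map_event`):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "k1_pop = 1.0                 # population-level rate constant\n",
        "log_scale_k1 = np.log(0.5)   # log of the between-patient scale\n",
        "eta_k1 = np.array([-1., 0., 1.])  # standard-normal offsets, one per patient\n",
        "\n",
        "# Non-centered transform: k1 is guaranteed positive, and eta = 0\n",
        "# recovers the population value k1_pop.\n",
        "k1 = np.exp(np.log(k1_pop) + eta_k1 * np.exp(log_scale_k1))\n",
        "print(k1)\n",
        "```\n",
        "\n",
        "Sampling `eta` from a standard normal and transforming, rather than sampling each patient's `k1` directly from a `LogNormal` centered at `k1_pop`, typically improves the posterior geometry for HMC when the data per patient are sparse."
      ]
    },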
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6zYQufGa2xtm"
      },
      "outputs": [],
      "source": [
        "num_hyper = 5\n",
        "num_dimensions = num_hyper + 2 * n_patients\n",
        "\n",
        "def pop_initialize(shape, key = random.PRNGKey(37272710)):\n",
        "  # init for k1_pop, k2_pop, and sigma\n",
        "  hyper_prior_location = jnp.array([jnp.log(1.5), jnp.log(0.25), 0.])\n",
        "  hyper_prior_scale = jnp.array([0.5, 0.1, 0.5])\n",
        "  init_hyper_param = jnp.exp(hyper_prior_scale * random.normal(key, shape + \\\n",
        "                             (3, )) + hyper_prior_location)\n",
        "\n",
        "  # init for log_scale_k1 and log_scale_k2\n",
        "  scale_prior_location = jnp.array([-1., -1.])\n",
        "  scale_prior_scale = jnp.array([0.25, 0.25])\n",
        "  init_scale = scale_prior_scale * random.normal(key, shape + (2, )) +\\\n",
        "    scale_prior_location\n",
        "\n",
        "  # inits for the etas\n",
        "  init_eta = random.normal(key, shape + (2 * n_patients, ))\n",
        "  return jnp.append(jnp.append(init_hyper_param, init_scale, axis = 1), \n",
        "                    init_eta, axis = 1)\n",
        "\n",
        "initial_state = pop_initialize((4, ))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T3Zh9soc2t-M"
      },
      "outputs": [],
      "source": [
        "# Map the columns of initial_state (k1_pop, k2_pop, sigma, log_scale_k1,\n",
        "# log_scale_k2, etas) to the order of the model's components. Note that\n",
        "# sigma sits in column 2: it is drawn with the exponentiated hyper inits,\n",
        "# while columns 3 and 4 hold the (unexponentiated) log-scale inits.\n",
        "initial_list = [initial_state[:, 0],  # k1_pop\n",
        "                initial_state[:, 1],  # k2_pop\n",
        "                initial_state[:, 3],  # log_scale_k1\n",
        "                initial_state[:, 4],  # log_scale_k2\n",
        "                initial_state[:, 2],  # sigma\n",
        "                initial_state[:, 5:(5 + n_patients)],                    # eta_k1\n",
        "                initial_state[:, (5 + n_patients):(5 + 2 * n_patients)]  # eta_k2\n",
        "                ]"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "_mjLAI5RvRNp",
        "eKflNcfnC4l1"
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
        "kind": "private"
      },
      "name": "PKmodels.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1RrCAGXR3VT2pJEqMbTYWPNQSnyIWraV4",
          "timestamp": 1632344208982
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
